分类
联系方式
  1. 新浪微博
  2. E-mail

RL Baselines3 Zoo ExperimentManager

介绍

ExperimentManager 是 RL Baselines3 Zoo 中的实验管理器,用于加载超参数,预处理,创建环境和 RL 模型,是 RL Baselines3 Zoo 中非常核心的一个类。

在这个类中调用了底层的 satble_baselines3 和 gym。

代码位于 utils\exp_manager.py。

使用方式

train.py 是 RL Baselines3 Zoo 中的训练脚本,底层基于 ExperimentManager 实现。

train.py 其实没做什么,主要是解析参数,主要逻辑都通过调用 ExperimentManager 实现:

exp_manager = ExperimentManager(
    args,
    # 省略一大堆参数
)

# Prepare experiment and launch hyperparameter optimization if needed
model = exp_manager.setup_experiment()

# Normal training
if model is not None:
    exp_manager.learn(model)
    exp_manager.save_trained_model(model)
else:
    exp_manager.hyperparameters_optimization()

有一个逻辑判断:如果模型不存在就进行训练,并且训练后保存;如果模型存在,则进行超参数调参优化。

属性

属性有很多,这里罗列一部分:

名称 类型 说明 备注
algo str 采用的算法
env_id str 采用的模型
normalize bool normalize 相关
normalize_kwargs Map normalize 相关

公有方法

setup_experiment 实验环境设置

加载超参数:

def setup_experiment(self) -> Optional[BaseAlgorithm]:
    """
    Read hyperparameters, pre-process them (create schedules, wrappers, callbacks, action noise objects)
    create the environment and possibly the model.

    :return: the initialized RL model
    """
    # 加载超参数
    hyperparams, saved_hyperparams = self.read_hyperparameters()
    # 超参数预处理
    hyperparams, self.env_wrapper, self.callbacks = self._preprocess_hyperparams(hyperparams)

    # 创建 Tensorbard 日志目录
    self.create_log_folder()
    # 创建回调
    self.create_callbacks()

    # Create env to have access to action space for action noise
    # 创建环境
    env = self.create_envs(self.n_envs, no_log=False)

    self._hyperparams = self._preprocess_action_noise(hyperparams, saved_hyperparams, env)

    # 如果是对已有模型进行连续学习,则加载之前训练的 agent
    if self.continue_training:
        model = self._load_pretrained_agent(self._hyperparams, env)
    # 如果是进行超参数调优,返回一个 null?这里的返回值是模型,也就是不返回模型?
    elif self.optimize_hyperparameters:
        return None
    # 训练新模型
    else:
        # Train an agent from scratch
        # ALGOS 是 ZOO 支持的所有算法,位于 utils\utils.py
        model = ALGOS[self.algo](
            env=env, # 传入环境
            tensorboard_log=self.tensorboard_log,  # 日志目录
            seed=self.seed, # 随机数
            verbose=self.verbose, # 是否开启话痨模式
            **self._hyperparams, # 把超参数传进算法
        )

    # 保存超参数
    self._save_config(saved_hyperparams)
    return model

超参数预处理:

hyperparams, self.env_wrapper, self.callbacks = self._preprocess_hyperparams(hyperparams)

learn 对模型进行强化学习

def learn(self, model: BaseAlgorithm) -> None:
    """
    :param model: an initialized RL model
    """
    # 解析模型输入参数
    kwargs = {}
    if self.log_interval > -1:
        kwargs = {"log_interval": self.log_interval}

    if len(self.callbacks) > 0:
        kwargs["callback"] = self.callbacks

    try:
        # 调用模型学习算法,看到也没传进去什么参数
        model.learn(self.n_timesteps, **kwargs)
    except KeyboardInterrupt:
        # this allows to save the model when interrupting training
        pass
    finally:
        # Release resources
        try:
            # 训练完毕释放资源
            model.env.close()
        except EOFError:
            pass

可以看到调模型的 learn 方法的时候传的参数很少,最多就 2 个。

前面处理的大量超参数,其实都是再模型创建(setup_experiment 实验环境设置)的时候传入的。

create_env 创建环境

def create_envs(self, n_envs: int, eval_env: bool = False, no_log: bool = False) -> VecEnv:
    """
    Create the environment and wrap it if necessary.

    :param n_envs:
    :param eval_env: Whether is it an environment used for evaluation or not
                     用于支持自定义环境,自定义的时候传入一个包名需要动态加载
    :param no_log: Do not log training when doing hyperparameter optim
        (issue with writing the same file)
    :return: the vectorized environment, with appropriate wrappers
             返回值是环境
    """
    # Do not log eval env (issue with writing the same file)
    log_dir = None if eval_env or no_log else self.save_path

    monitor_kwargs = {}
    # Special case for GoalEnvs: log success rate too
    # 对某种环境进行特殊处理
    if "Neck" in self.env_id or self.is_robotics_env(self.env_id) or "parking-v0" in self.env_id:
        monitor_kwargs = dict(info_keywords=("is_success",))

    # On most env, SubprocVecEnv does not help and is quite memory hungry
    # therefore we use DummyVecEnv by default
    # make_vec_env 是 stable_baselines3.common.env_util 中提供的方法用于创建环境
    # 对于大多数环境来说,SubprocVecEnv 没有什么帮助,而且内存开销大
    # 因此我们默认采用 DummyVecEnv
    # 大多数传入属性都是从类属性中获取的
    env = make_vec_env(
        env_id=self.env_id,
        n_envs=n_envs,
        seed=self.seed,
        env_kwargs=self.env_kwargs,
        monitor_dir=log_dir,
        wrapper_class=self.env_wrapper,
        vec_env_cls=self.vec_env_class,
        vec_env_kwargs=self.vec_env_kwargs,
        monitor_kwargs=monitor_kwargs,
    )

    # Wrap the env into a VecNormalize wrapper if needed
    # and load saved statistics when present
    # 对环境进行了一个标准化,这块需要再看看
    env = self._maybe_normalize(env, eval_env)

    # Optional Frame-stacking,帧-栈是什么?
    if self.frame_stack is not None:
        n_stack = self.frame_stack
        env = VecFrameStack(env, n_stack)
        if self.verbose > 0:
            print(f"Stacking {n_stack} frames")

    # Wrap if needed to re-order channels
    # (switch from channel last to channel first convention)
    # 如果是图像相关,又封装了一层环境,封装到 VecTransposeImage 里面
    if is_image_space(env.observation_space) and not is_image_space_channels_first(env.observation_space):
        if self.verbose > 0:
            print("Wrapping into a VecTransposeImage")
        env = VecTransposeImage(env)

    return env

私有方法

read_hyperparameters 读取超参数

读取对应模型的超参数文件("hyperparams/{self.algo}.yml")。使用 yaml 进行装载。

返回值有两个,都是解析完的超参数:

  • hyperparams
  • saved_hyperparams:用于存储的超参数

_preprocess_hyperparams 超参数预处理

超参数的定义参见 RL Baselines3 Zoo 超参数

  1. 首先执行 _preprocess_schedules 处理 learning_rate、clip_range、clip_range_vf 这三个参数,处理结果还存在 hyperparams 里面。
  2. 设置执行步长状态 n_timesteps,如果外界有传入则用外界的(override),否则用超参数文件里的
  3. 处理 normalization
  4. 处理策略、缓存相关超参数(policy_kwargs、replay_buffer_class、replay_buffer_kwargs),直接用 eval
  5. 删除超参数 key,以便能够传入模型构造函数(n_envs、n_timesteps、frame_stack)
  6. 将超参数封装成了一个类

_preprocess_schedules 调度预处理

在超参数中寻找 learning_rate、clip_range、clip_range_vf。

分两种情况,如果是字符串:

  • 格式分两段,下划线间隔,第一部分是 schedule,第二部分是 initial_value。
  • 对初始值通过 linear_schedule 封装,并替换掉超参数中的原值。
  • (schedule 没用上,应该是这部分还没开发完。)

如果是数值:

  • 这里包了一个 constant_fn
  • 定义在 from stable_baselines3.common.utils import constant_fn

_preprocess_normalization

只有超参数里面设置了 normalize 才进行处理,否则不处理。

如果设置了 gamma,还要把它存到 normalize_kwargs 里面。

get_callback_list 获取回调列表

根据超参数中设置的 callback 创建回调。

举例:

# 单个回调
callback: stable_baselines3.common.callbacks.CheckpointCallback

# 多个回调
callback:
    - utils.callbacks.PlotActionWrapper
    - stable_baselines3.common.callbacks.CheckpointCallback

如果没指定,就返回一个空数组。